#!/usr/bin/env python

#####################################################################
#
# portstat is a tool for summarizing network statistics.
#
#####################################################################

import argparse
import cPickle as pickle
import datetime
import os.path
import swsssdk
import sys
import time

from collections import namedtuple, OrderedDict
from natsort import natsorted
from tabulate import tabulate
from utilities_common.netstat import ns_diff, ns_brate, ns_prate, ns_util, table_as_json
from utilities_common.intf_filter import parse_interface_in_filter

# Fallback port speed used by Portstat.get_port_speed() when APPL_DB holds no
# speed for a port. NOTE(review): presumably expressed in Gbps, since the DB
# value is divided by 1000 before use -- confirm against ns_util().
PORT_RATE = 40

# One snapshot of a port's counters. The field order must match the bucket
# indexes in counter_bucket_dict, because the tuple is built positionally.
NStats = namedtuple("NStats", [
    "rx_ok", "rx_err", "rx_drop", "rx_ovr",
    "tx_ok", "tx_err", "tx_drop", "tx_ovr",
    "rx_byt", "tx_byt",
])

# Column headings for each display mode. The first two columns are always the
# interface name and its state.
header_all = ['IFACE', 'STATE', 'RX_OK', 'RX_BPS', 'RX_PPS', 'RX_UTIL', 'RX_ERR', 'RX_DRP', 'RX_OVR',
          'TX_OK', 'TX_BPS', 'TX_PPS', 'TX_UTIL', 'TX_ERR', 'TX_DRP', 'TX_OVR']
header_std = ['IFACE', 'STATE', 'RX_OK', 'RX_BPS', 'RX_UTIL', 'RX_ERR', 'RX_DRP', 'RX_OVR',
          'TX_OK', 'TX_BPS', 'TX_UTIL', 'TX_ERR', 'TX_DRP', 'TX_OVR']
header_errors_only = ['IFACE', 'STATE', 'RX_ERR', 'RX_DRP', 'RX_OVR', 'TX_ERR', 'TX_DRP', 'TX_OVR']
header_rates_only = ['IFACE', 'STATE', 'RX_OK', 'RX_BPS', 'RX_PPS', 'RX_UTIL', 'TX_OK', 'TX_BPS', 'TX_PPS', 'TX_UTIL']

# SAI counter names grouped by the position they feed in the per-port counter
# tuple: the tuple index below is the bucket number. Counters sharing a bucket
# are summed together when the snapshot is built.
_COUNTERS_BY_BUCKET = (
    ('SAI_PORT_STAT_IF_IN_UCAST_PKTS', 'SAI_PORT_STAT_IF_IN_NON_UCAST_PKTS'),
    ('SAI_PORT_STAT_IF_IN_ERRORS',),
    ('SAI_PORT_STAT_IF_IN_DISCARDS',),
    ('SAI_PORT_STAT_ETHER_RX_OVERSIZE_PKTS',),
    ('SAI_PORT_STAT_IF_OUT_UCAST_PKTS', 'SAI_PORT_STAT_IF_OUT_NON_UCAST_PKTS'),
    ('SAI_PORT_STAT_IF_OUT_ERRORS',),
    ('SAI_PORT_STAT_IF_OUT_DISCARDS',),
    ('SAI_PORT_STAT_ETHER_TX_OVERSIZE_PKTS',),
    ('SAI_PORT_STAT_IF_IN_OCTETS',),
    ('SAI_PORT_STAT_IF_OUT_OCTETS',),
)

# Flat lookup: SAI counter name -> bucket index.
counter_bucket_dict = dict(
    (counter_name, bucket)
    for bucket, counter_names in enumerate(_COUNTERS_BY_BUCKET)
    for counter_name in counter_names
)

# Placeholder shown for any value that cannot be read or computed.
STATUS_NA = 'N/A'

# COUNTERS_DB keys.
COUNTER_TABLE_PREFIX = "COUNTERS:"
COUNTERS_PORT_NAME_MAP = "COUNTERS_PORT_NAME_MAP"

# APPL_DB port table key prefix, its fields, and the status values compared
# against (after upper-casing the stored value).
PORT_STATUS_TABLE_PREFIX = "PORT_TABLE:"
PORT_OPER_STATUS_FIELD = "oper_status"
PORT_ADMIN_STATUS_FIELD = "admin_status"
PORT_STATUS_VALUE_UP = 'UP'
PORT_STATUS_VALUE_DOWN = 'DOWN'
PORT_SPEED_FIELD = "speed"

# One-letter codes printed in the STATE column.
PORT_STATE_UP = 'U'
PORT_STATE_DOWN = 'D'
PORT_STATE_DISABLED = 'X'

class Portstat(object):
    """
        Reads per-port counters from COUNTERS_DB and port speed/state from
        APPL_DB, and prints them as a text table or as JSON.
    """

    # Column layouts for every display mode. Each entry is (kind, field):
    #   'count' - the NStats field itself (or its delta against a baseline)
    #   'brate' - byte rate, 'prate' - packet rate, 'util' - link utilization;
    #             these need a baseline and are shown as N/A without one.
    # The order must line up with the matching header_* list, minus the two
    # leading IFACE/STATE columns.
    _COLUMNS_ALL = (
        ('count', 'rx_ok'), ('brate', 'rx_byt'), ('prate', 'rx_ok'), ('util', 'rx_byt'),
        ('count', 'rx_err'), ('count', 'rx_drop'), ('count', 'rx_ovr'),
        ('count', 'tx_ok'), ('brate', 'tx_byt'), ('prate', 'tx_ok'), ('util', 'tx_byt'),
        ('count', 'tx_err'), ('count', 'tx_drop'), ('count', 'tx_ovr'),
    )
    _COLUMNS_STD = (
        ('count', 'rx_ok'), ('brate', 'rx_byt'), ('util', 'rx_byt'),
        ('count', 'rx_err'), ('count', 'rx_drop'), ('count', 'rx_ovr'),
        ('count', 'tx_ok'), ('brate', 'tx_byt'), ('util', 'tx_byt'),
        ('count', 'tx_err'), ('count', 'tx_drop'), ('count', 'tx_ovr'),
    )
    _COLUMNS_ERRORS_ONLY = (
        ('count', 'rx_err'), ('count', 'rx_drop'), ('count', 'rx_ovr'),
        ('count', 'tx_err'), ('count', 'tx_drop'), ('count', 'tx_ovr'),
    )
    _COLUMNS_RATES_ONLY = (
        ('count', 'rx_ok'), ('brate', 'rx_byt'), ('prate', 'rx_ok'), ('util', 'rx_byt'),
        ('count', 'tx_ok'), ('brate', 'tx_byt'), ('prate', 'tx_ok'), ('util', 'tx_byt'),
    )

    def __init__(self):
        """
            Connect to the COUNTERS_DB and APPL_DB databases on 127.0.0.1.
        """
        self.db = swsssdk.SonicV2Connector(host='127.0.0.1')
        self.db.connect(self.db.COUNTERS_DB)
        self.db.connect(self.db.APPL_DB)

    def get_cnstat(self):
        """
            Get the counters info from database.

            Returns an OrderedDict whose first entry, 'time', is the moment
            the snapshot was taken; every further entry maps a port name (in
            natural sort order) to an NStats tuple. Each NStats field is a
            decimal string, or STATUS_NA when a contributing counter is
            missing from the database.
        """
        def get_counters(table_id):
            """
                Get the counters from specific table.
            """
            full_table_id = COUNTER_TABLE_PREFIX + table_id
            fields = ["0"] * len(NStats._fields)
            for counter_name, pos in counter_bucket_dict.items():
                counter_data = self.db.get(self.db.COUNTERS_DB, full_table_id, counter_name)
                if counter_data is None:
                    # One missing counter invalidates the whole bucket.
                    fields[pos] = STATUS_NA
                elif fields[pos] != STATUS_NA:
                    # Several counters may share a bucket; accumulate them.
                    fields[pos] = str(int(fields[pos]) + int(counter_data))
            return NStats._make(fields)

        # Get the info from database
        counter_port_name_map = self.db.get_all(self.db.COUNTERS_DB, COUNTERS_PORT_NAME_MAP)
        # Build a dictionary of the stats
        cnstat_dict = OrderedDict()
        cnstat_dict['time'] = datetime.datetime.now()
        if counter_port_name_map is None:
            return cnstat_dict
        for port in natsorted(counter_port_name_map):
            cnstat_dict[port] = get_counters(counter_port_name_map[port])
        return cnstat_dict

    def get_port_speed(self, port_name):
        """
            Get the port speed

            Reads the speed field of the port from APPL_DB and divides it by
            1000; falls back to PORT_RATE when the field is absent.
        """
        full_table_id = PORT_STATUS_TABLE_PREFIX + port_name
        speed = self.db.get(self.db.APPL_DB, full_table_id, PORT_SPEED_FIELD)
        if speed is None:
            return PORT_RATE
        return int(speed)//1000

    def get_port_state(self, port_name):
        """
            Get the port state

            Returns PORT_STATE_DISABLED when the port is administratively
            down, PORT_STATE_UP / PORT_STATE_DOWN according to the operational
            status when it is administratively up, and STATUS_NA when either
            status is missing or holds an unexpected value.
        """
        full_table_id = PORT_STATUS_TABLE_PREFIX + port_name
        admin_state = self.db.get(self.db.APPL_DB, full_table_id, PORT_ADMIN_STATUS_FIELD)
        oper_state = self.db.get(self.db.APPL_DB, full_table_id, PORT_OPER_STATUS_FIELD)
        if admin_state is None or oper_state is None:
            return STATUS_NA

        admin_state = admin_state.upper()
        oper_state = oper_state.upper()
        if admin_state == PORT_STATUS_VALUE_DOWN:
            return PORT_STATE_DISABLED
        if admin_state == PORT_STATUS_VALUE_UP:
            if oper_state == PORT_STATUS_VALUE_UP:
                return PORT_STATE_UP
            if oper_state == PORT_STATUS_VALUE_DOWN:
                return PORT_STATE_DOWN
        return STATUS_NA

    def _select_layout(self, print_all, errors_only, rates_only):
        """
            Pick the (header, columns) pair for the requested display mode.
            Precedence when several flags are set: all, errors, rates.
        """
        if print_all:
            return header_all, self._COLUMNS_ALL
        if errors_only:
            return header_errors_only, self._COLUMNS_ERRORS_ONLY
        if rates_only:
            return header_rates_only, self._COLUMNS_RATES_ONLY
        return header_std, self._COLUMNS_STD

    def _raw_row(self, cntr, columns):
        """
            Build the value cells for a port that has no baseline: counters
            are shown as-is and every rate/utilization cell is STATUS_NA.
        """
        return tuple(getattr(cntr, field) if kind == 'count' else STATUS_NA
                     for kind, field in columns)

    def _diff_row(self, cntr, old_cntr, columns, time_gap, port_speed):
        """
            Build the value cells for a port relative to its baseline.
        """
        cells = []
        for kind, field in columns:
            new_value = getattr(cntr, field)
            old_value = getattr(old_cntr, field)
            if kind == 'count':
                cells.append(ns_diff(new_value, old_value))
            elif kind == 'brate':
                cells.append(ns_brate(new_value, old_value, time_gap))
            elif kind == 'prate':
                cells.append(ns_prate(new_value, old_value, time_gap))
            else:
                cells.append(ns_util(new_value, old_value, time_gap, port_speed))
        return tuple(cells)

    def _print_table(self, table, header, use_json):
        """
            Print the rows either as JSON or as a right-aligned text table.
        """
        if use_json:
            print(table_as_json(table, header))
        else:
            print(tabulate(table, header, tablefmt='simple', stralign='right'))

    def cnstat_print(self, cnstat_dict, intf_list, use_json, print_all, errors_only, rates_only):
        """
            Print the cnstat.

            cnstat_dict is a snapshot as returned by get_cnstat(). When
            intf_list is non-empty only the ports it contains are shown.
            Rates cannot be derived from a single snapshot, so those columns
            are printed as STATUS_NA.
        """
        # The header is chosen up front so that an empty result still prints
        # a well-formed (header-only) table instead of passing header=None.
        header, columns = self._select_layout(print_all, errors_only, rates_only)

        table = []
        for key, data in cnstat_dict.items():
            if key == 'time':
                continue
            if intf_list and key not in intf_list:
                continue
            table.append((key, self.get_port_state(key)) + self._raw_row(data, columns))

        self._print_table(table, header, use_json)

    def cnstat_diff_print(self, cnstat_new_dict, cnstat_old_dict, intf_list, use_json, print_all, errors_only, rates_only):
        """
            Print the difference between two cnstat results.

            Both dictionaries are snapshots as returned by get_cnstat(); the
            elapsed time between their 'time' entries is used for the rate
            columns. A port that is missing from the old snapshot has no
            baseline: its raw counters are printed and its rate columns are
            STATUS_NA, in every display mode.
        """
        header, columns = self._select_layout(print_all, errors_only, rates_only)

        time_gap = cnstat_new_dict.get('time') - cnstat_old_dict.get('time')
        time_gap = time_gap.total_seconds()

        table = []
        for key, cntr in cnstat_new_dict.items():
            if key == 'time':
                continue
            if intf_list and key not in intf_list:
                continue

            old_cntr = cnstat_old_dict.get(key)
            if old_cntr is None:
                cells = self._raw_row(cntr, columns)
            else:
                # Utilization is always computed against the port's actual
                # speed rather than ns_util's built-in default.
                port_speed = self.get_port_speed(key)
                cells = self._diff_row(cntr, old_cntr, columns, time_gap, port_speed)
            table.append((key, self.get_port_state(key)) + cells)

        self._print_table(table, header, use_json)


def main():
    """
        Command-line entry point.

        Parses the options, then either deletes saved snapshots (-d / -D),
        prints raw counters (-r), saves a fresh snapshot (-c), prints the
        counters relative to the saved snapshot, or (-p N) samples twice N
        seconds apart and prints the difference. Snapshots are pickled under
        /tmp/portstat-<uid>/, in a file named <uid> or <uid>-<tag>.
    """
    # Imported here so this function carries its own dependency.
    import errno

    parser  = argparse.ArgumentParser(description='Display the ports state and counters',
                                      version='1.0.0',
                                      formatter_class=argparse.RawTextHelpFormatter,
                                      epilog="""
Port state: (U)-Up (D)-Down (X)-Disabled
Examples:
  portstat -c -t test
  portstat -t test
  portstat -d -t test
  portstat -e 
  portstat
  portstat -r
  portstat -R
  portstat -a
  portstat -p 20
""")

    parser.add_argument('-a', '--all', action='store_true', help='Display all the stats counters')
    parser.add_argument('-c', '--clear', action='store_true', help='Copy & clear stats')
    parser.add_argument('-d', '--delete', action='store_true', help='Delete saved stats, either the uid or the specified tag')
    parser.add_argument('-D', '--delete-all', action='store_true', help='Delete all saved stats')
    parser.add_argument('-e', '--errors', action='store_true', help='Display interface errors')
    parser.add_argument('-j', '--json', action='store_true', help='Display in JSON format')
    parser.add_argument('-r', '--raw', action='store_true', help='Raw stats (unmodified output of netstat)')
    parser.add_argument('-R', '--rate', action='store_true', help='Display interface rates')
    parser.add_argument('-t', '--tag', type=str, help='Save stats with name TAG', default=None)
    parser.add_argument('-p', '--period', type=int, help='Display stats over a specified period (in seconds).', default=0)
    parser.add_argument('-i', '--interface', type=str, help='Display stats for interface lists.', default=None)
    args = parser.parse_args()

    save_fresh_stats = args.clear
    delete_saved_stats = args.delete
    delete_all_stats = args.delete_all
    errors_only = args.errors
    rates_only = args.rate
    use_json = args.json
    raw_stats = args.raw
    tag_name = args.tag
    uid = str(os.getuid())
    wait_time_in_seconds = args.period
    print_all = args.all
    intf_fs = args.interface
    if tag_name is not None:
        cnstat_file = uid + "-" + tag_name
    else:
        cnstat_file = uid

    cnstat_dir = "/tmp/portstat-" + uid
    cnstat_fqn_file = cnstat_dir + "/" + cnstat_file

    if delete_all_stats:
        # The os.* file functions raise OSError, so that is what must be
        # caught. A missing directory simply means nothing was saved.
        try:
            for saved_name in os.listdir(cnstat_dir):
                os.remove(cnstat_dir + "/" + saved_name)
            os.rmdir(cnstat_dir)
        except (IOError, OSError) as e:
            if e.errno != errno.ENOENT:
                print("%s %s" % (e.errno, e))
                sys.exit(1)
        sys.exit(0)

    if delete_saved_stats:
        try:
            os.remove(cnstat_fqn_file)
        except (IOError, OSError) as e:
            if e.errno != errno.ENOENT:
                print("%s %s" % (e.errno, e))
                sys.exit(1)
        # Drop the directory once the last snapshot is gone. This runs after
        # the error handling above, so a real failure keeps its exit code.
        try:
            if not os.listdir(cnstat_dir):
                os.rmdir(cnstat_dir)
        except (IOError, OSError) as e:
            if e.errno != errno.ENOENT:
                print("%s %s" % (e.errno, e))
                sys.exit(1)
        sys.exit(0)

    intf_list = parse_interface_in_filter(intf_fs)

    portstat = Portstat()
    # The cnstat_dict just give an ordered dict of all output.
    cnstat_dict = portstat.get_cnstat()

    # Now decide what information to display
    if raw_stats:
        portstat.cnstat_print(cnstat_dict, intf_list, use_json, print_all, errors_only, rates_only)
        sys.exit(0)

    # At this point, either we'll create a file or open an existing one.
    if not os.path.exists(cnstat_dir):
        try:
            os.makedirs(cnstat_dir)
        except (IOError, OSError) as e:
            print("%s %s" % (e.errno, e))
            sys.exit(1)

    if save_fresh_stats:
        try:
            with open(cnstat_fqn_file, 'wb') as stats_file:
                pickle.dump(cnstat_dict, stats_file)
        except IOError as e:
            sys.exit(e.errno)
        else:
            print("Cleared counters")
            sys.exit(0)

    if wait_time_in_seconds == 0:
        if os.path.isfile(cnstat_fqn_file):
            # NOTE(security): the snapshot lives at a predictable path under
            # /tmp and unpickling executes whatever the file describes. If
            # another local user can pre-create this directory, they control
            # what gets loaded here. Consider a private directory (mode 0700
            # with an ownership check) or a non-executable format.
            try:
                with open(cnstat_fqn_file, 'rb') as stats_file:
                    cnstat_cached_dict = pickle.load(stats_file)
            except IOError as e:
                print("%s %s" % (e.errno, e))
            else:
                print("Last cached time was " + str(cnstat_cached_dict.get('time')))
                portstat.cnstat_diff_print(cnstat_dict, cnstat_cached_dict, intf_list, use_json, print_all, errors_only, rates_only)
        else:
            if tag_name:
                print("\nFile '%s' does not exist" % cnstat_fqn_file)
                print("Did you run 'portstat -c -t %s' to record the counters via tag %s?\n" % (tag_name, tag_name))
            else:
                portstat.cnstat_print(cnstat_dict, intf_list, use_json, print_all, errors_only, rates_only)
    else:
        #wait for the specified time and then gather the new stats and output the difference.
        time.sleep(wait_time_in_seconds)
        print("The rates are calculated within %s seconds period" % wait_time_in_seconds)
        cnstat_new_dict = portstat.get_cnstat()
        portstat.cnstat_diff_print(cnstat_new_dict, cnstat_dict, intf_list, use_json, print_all, errors_only, rates_only)

if __name__ == "__main__":
    main()
